{
 "cells": [
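  {
   "cell_type": "markdown",
   "id": "a246687d",
   "metadata": {},
   "source": [
    "# The Product Pricer\n",
    "\n",
    "A model that estimates how much a product costs from its description.\n",
    "\n",
    "This notebook explores the item data, builds train/test splits, and benchmarks local Ollama models (llama3.2, mistral, tinyllama) as zero-shot price estimators.\n"
   ]
  },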
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3792ce5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the LangChain Ollama integration quietly (-q) using uv's pip interface.\n",
    "# NOTE(review): the version is unpinned - consider pinning it for reproducible runs.\n",
    "! uv -q pip install langchain-ollama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "390c3ce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from huggingface_hub import login\n",
    "from datasets import load_dataset, Dataset, DatasetDict\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import re\n",
    "from langchain_ollama import OllamaLLM\n",
    "from openai import OpenAI\n",
    "from testing import Tester\n",
    "import json\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a8ff331",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load variables from a local .env file; override=True lets .env values replace already-set env vars.\n",
    "load_dotenv(override=True)\n",
    "hf_token = os.getenv(\"HF_TOKEN\")\n",
    "# Authenticate with the Hugging Face Hub and also store the token as a git credential.\n",
    "# NOTE(review): if HF_TOKEN is missing, hf_token is None and login() presumably falls back to an interactive prompt - confirm.\n",
    "login(hf_token, add_to_git_credential=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1051e21e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from items import Item\n",
    "from loaders import ItemLoader\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "290fa868",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset category names passed to ItemLoader; extend this list to load more categories.\n",
    "dataset_names = [\n",
    "  \"Appliances\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12ffad66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load each dataset category with ItemLoader and collect everything it returns into one flat list.\n",
    "items = []\n",
    "for dataset_name in dataset_names:\n",
    "    loader = ItemLoader(dataset_name)\n",
    "    items.extend(loader.load())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b3890d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the total number of loaded items (with thousands separators).\n",
    "print(f\"A grand total of {len(items):,} items\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "246ab22a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of per-item token counts (bin edges every 10 tokens from 0 to 290;\n",
    "# counts outside that range are not drawn, but the title's maximum covers all items).\n",
    "\n",
    "tokens = [item.token_count for item in items]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15, 6))\n",
    "ax.hist(tokens, bins=range(0, 300, 10), rwidth=0.7, color=\"skyblue\")\n",
    "ax.set_title(f\"Token counts: Avg {sum(tokens)/len(tokens):,.1f} and highest {max(tokens):,}\\n\")\n",
    "ax.set_xlabel('Length (tokens)')\n",
    "ax.set_ylabel('Count')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a49a4d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of item prices (bin edges every $10 from 0 to 990; pricier items are\n",
    "# not drawn, but the title's maximum covers all items).\n",
    "\n",
    "prices = [item.price for item in items]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15, 6))\n",
    "ax.hist(prices, bins=range(0, 1000, 10), rwidth=0.7, color=\"blueviolet\")\n",
    "ax.set_title(f\"Prices: Avg {sum(prices)/len(prices):,.1f} and highest {max(prices):,}\\n\")\n",
    "ax.set_xlabel('Price ($)')\n",
    "ax.set_ylabel('Count')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e4ea1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Does price track prompt length? Plot each prompt's character count against its price.\n",
    "\n",
    "# NOTE: `sample` is an alias of `items`, not a copy - mutating one mutates the other.\n",
    "sample = items\n",
    "\n",
    "sizes = [len(item.prompt) for item in sample]\n",
    "prices = [item.price for item in sample]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15, 8))\n",
    "ax.scatter(sizes, prices, s=0.2, color=\"red\")\n",
    "ax.set(xlabel='Size', ylabel='Price', title='Is there a simple correlation?')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6620daa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def report(item):\n",
    "    \"\"\"Print an item's prompt, then its last 10 encoded tokens and their decoded text.\"\"\"\n",
    "    tokenizer = Item.tokenizer\n",
    "    prompt = item.prompt\n",
    "    tail = tokenizer.encode(prompt)[-10:]\n",
    "    print(prompt)\n",
    "    print(tail)\n",
    "    print(tokenizer.batch_decode(tail))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af71d177",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one item: its full prompt and how the end of that prompt tokenizes.\n",
    "report(sample[50])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75ab3c21",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Split sizes: the test set is taken from the items right after the training set.\n",
    "TRAIN_SIZE = 25_000\n",
    "TEST_SIZE = 2_000\n",
    "\n",
    "# Shuffle a copy rather than `sample` itself. `sample` aliases `items`, so an in-place\n",
    "# shuffle would mutate the loaded data and make this cell non-idempotent: every re-run\n",
    "# would reshuffle already-shuffled items and silently produce a different split.\n",
    "# Re-seeding and shuffling a fresh copy yields the same permutation every time.\n",
    "random.seed(42)\n",
    "shuffled = list(sample)\n",
    "random.shuffle(shuffled)\n",
    "train = shuffled[:TRAIN_SIZE]\n",
    "test = shuffled[TRAIN_SIZE:TRAIN_SIZE + TEST_SIZE]\n",
    "print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d5cbd3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the full prompt of the first training item.\n",
    "print(train[0].prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39de86d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the prompt produced by test_prompt() for the first test item.\n",
    "print(test[0].test_prompt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65480df9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram of prices for the first 250 test items (bin edges every $10 from 0 to 990).\n",
    "\n",
    "prices = [float(item.price) for item in test[:250]]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15, 6))\n",
    "ax.hist(prices, bins=range(0, 1000, 10), rwidth=0.7, color=\"darkblue\")\n",
    "ax.set_title(f\"Avg {sum(prices)/len(prices):.2f} and highest {max(prices):,.2f}\\n\")\n",
    "ax.set_xlabel('Price ($)')\n",
    "ax.set_ylabel('Count')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a315b10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prices (as floats) of all test items above 99.999 - i.e. $100 and up, assuming cent precision - for the token-count check that follows.\n",
    "filtered_prices = [float(item.price) for item in test if item.price > 99.999]"
   ]
  },
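  {
   "cell_type": "markdown",
   "id": "5693c9c6",
   "metadata": {},
   "source": [
    "### Confirm that the tokenizer encodes every price of $100 or more into exactly 3 tokens\n",
    "\n",
    "The expectation is one token for the 3-digit dollar amount, one for the decimal point and one for the digits after it (e.g. `123.45`)."
   ]
  },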
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99e8cfc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each price string (e.g. \"123.45\") should encode to exactly 3 tokens. On failure, report\n",
    "# the offending price and how it was split; the name `price_tokens` avoids clobbering the\n",
    "# `tokens` list built for the token-count histogram.\n",
    "for price in filtered_prices:\n",
    "    price_tokens = Item.tokenizer.encode(f\"{price}\", add_special_tokens=False)\n",
    "    assert len(price_tokens) == 3, f\"{price} -> {Item.tokenizer.batch_decode(price_tokens)}\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3159195",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bdc5dd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def messages_for(item):\n",
    "    \"\"\"Build the chat messages asking an LLM to price `item`.\n",
    "\n",
    "    The user turn is item.test_prompt() with every \" to the nearest dollar\" and every\n",
    "    blank-line-prefixed \"Price is $\" removed; the final assistant turn is a bare\n",
    "    \"Price is $\" prefix for the model to continue.\n",
    "    \"\"\"\n",
    "    prompt = item.test_prompt()\n",
    "    for fragment in (\" to the nearest dollar\", \"\\n\\nPrice is $\"):\n",
    "        prompt = prompt.replace(fragment, \"\")\n",
    "    system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
    "    return [\n",
    "        dict(role=\"system\", content=system_message),\n",
    "        dict(role=\"user\", content=prompt),\n",
    "        dict(role=\"assistant\", content=\"Price is $\"),\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "211b0658",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A utility function to extract the price from a string\n",
    "\n",
    "def get_price(s):\n",
    "    \"\"\"Return the first number found in `s` as a float, or 0.0 if there is none.\n",
    "\n",
    "    Dollar signs and thousands separators are stripped first. The optional sign is\n",
    "    grouped so it applies to integers as well as decimals; without the group \"-5\"\n",
    "    would parse as 5.0 while \"-5.5\" parses as -5.5.\n",
    "    \"\"\"\n",
    "    s = s.replace('$', '').replace(',', '')\n",
    "    match = re.search(r\"[-+]?(?:\\d*\\.\\d+|\\d+)\", s)\n",
    "    return float(match.group()) if match else 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee01da84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the items into a \"jsonl\" string: one JSON object per line, in the form\n",
    "# {\"messages\": [{\"role\": \"system\", \"content\": \"You estimate prices...\"}, ..., {\"role\": \"assistant\", \"content\": \"Price is $123.45\"}]}\n",
    "\n",
    "\n",
    "def make_jsonl(items):\n",
    "    \"\"\"Serialize `items` as chat-format JSONL, one {\"messages\": [...]} object per line.\n",
    "\n",
    "    Messages come from messages_for(), but its final assistant turn is only the\n",
    "    inference prefix \"Price is $\" with no answer. Each row here replaces that turn\n",
    "    with the item's actual price (2 decimals) so every row carries its target;\n",
    "    presumably this output feeds fine-tuning - confirm.\n",
    "    Returns an empty string when `items` is empty.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for item in items:\n",
    "        messages = messages_for(item)\n",
    "        messages[-1] = {\"role\": \"assistant\", \"content\": f\"Price is ${item.price:.2f}\"}\n",
    "        rows.append(json.dumps({\"messages\": messages}))\n",
    "    return \"\\n\".join(rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f23e8959",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the items into jsonl and write them to a file\n",
    "\n",
    "def write_jsonl(items, filename):\n",
    "    \"\"\"Write make_jsonl(items) to `filename`, replacing any existing content.\n",
    "\n",
    "    The JSONL text is built before the file is opened, so an exception during\n",
    "    serialization leaves an existing file untouched instead of truncating it.\n",
    "    \"\"\"\n",
    "    jsonl = make_jsonl(items)\n",
    "    with open(filename, \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(jsonl)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6a83580",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "451b974f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the lite train/test item collections saved as pickle files.\n",
    "# SECURITY: pickle.load can execute arbitrary code - only load pickle files you trust.\n",
    "# NOTE(review): train_lite is not used anywhere below in this notebook.\n",
    "with open('train_lite.pkl', 'rb') as f:\n",
    "    train_lite = pickle.load(f)\n",
    "\n",
    "with open('test_lite.pkl', 'rb') as f:\n",
    "    test_lite = pickle.load(f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f365d65c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the chat messages built for the first lite test item.\n",
    "messages_for(test_lite[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57b0b160",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick check of get_price on a sample model reply (expected result: 99.99).\n",
    "get_price(\"The price is roughly $99.99 because blah blah\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff3e4670",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f62c94b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local models served through Ollama, one LangChain OllamaLLM client per model.\n",
    "# NOTE(review): presumably each model must already be pulled locally (e.g. `ollama pull mistral`).\n",
    "MODEL_LLAMA3_2 = \"llama3.2\"\n",
    "MODEL_MISTRAL = \"mistral\"\n",
    "MODEL_TINY_LLAMA = \"tinyllama\"\n",
    "\n",
    "llm3_2 = OllamaLLM(model=MODEL_LLAMA3_2)\n",
    "llmMistral = OllamaLLM(model=MODEL_MISTRAL)\n",
    "llmTinyLlama = OllamaLLM(model=MODEL_TINY_LLAMA)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d18394fb",
   "metadata": {},
   "source": [
    "## Model Tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dac335f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _estimate_price(llm, item):\n",
    "    \"\"\"Prompt `llm` with messages_for(item) and parse a price from its reply via get_price().\"\"\"\n",
    "    response = llm.invoke(messages_for(item))\n",
    "    return get_price(response)\n",
    "\n",
    "\n",
    "def llama3_2_model(item):\n",
    "    \"\"\"Estimate `item`'s price with the local llama3.2 model.\"\"\"\n",
    "    return _estimate_price(llm3_2, item)\n",
    "\n",
    "\n",
    "def mistral_model(item):\n",
    "    \"\"\"Estimate `item`'s price with the local mistral model.\"\"\"\n",
    "    return _estimate_price(llmMistral, item)\n",
    "\n",
    "\n",
    "def tinyllama_model(item):\n",
    "    \"\"\"Estimate `item`'s price with the local tinyllama model.\"\"\"\n",
    "    return _estimate_price(llmTinyLlama, item)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "062e78c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth price of the first lite test item, for comparison with the model estimates.\n",
    "test_lite[0].price"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c58756f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate llama3.2 price estimates on the lite test set with the Tester harness.\n",
    "Tester.test(llama3_2_model, test_lite)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "899e2401",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate mistral price estimates on the lite test set with the Tester harness.\n",
    "Tester.test(mistral_model, test_lite)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f5bc9ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate tinyllama price estimates on the lite test set with the Tester harness.\n",
    "Tester.test(tinyllama_model, test_lite)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
